#!/usr/bin/env python
# coding: utf-8

# In[1]:


import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
import pandas as pd
from pathos.multiprocessing import Pool
from linearmodels import OLS
import datetime
from jm.library.data_helper import filter_data, date_to_str, is_between, is_weakly_less_than
from jm.library.likelihood import Likelihood
import warnings
import scipy
from jm.library.additional_helpers import plot_over_time, payment_cols, priority_cols, action_cols, \
    relative_payment_cols, fwdiff_payment_cols, relative_fwdiff_payment_cols, action_nomedida_cols

import statsmodels.api as sm
import statsmodels.formula.api as smf
import patsy

def areg(formula,data=None,absorb=None,cluster=None):
    """OLS with one absorbed categorical fixed effect, modelled on Stata's ``areg``.

    The dependent variable and every regressor built from ``formula`` are
    demeaned within the groups of ``data[absorb]``, then the grand mean is
    added back. The fixed effects are therefore removed while the intercept
    keeps its usual level.

    Parameters
    ----------
    formula : str
        patsy formula, e.g. ``'y ~ x1 + x2'``.
    data : pandas.DataFrame
        Data the formula is evaluated on. It must also contain the ``absorb``
        and ``cluster`` columns.
    absorb : column label, optional
        Fixed-effect grouping variable. If None, no demeaning is done and no
        degrees-of-freedom adjustment is made.
    cluster : column label, optional
        Cluster variable for cluster-robust standard errors. If None, the
        default (non-robust) covariance is used.

    Returns
    -------
    statsmodels regression results from ``sm.OLS(...).fit``.
    """
    y, X = patsy.dmatrices(formula, data, return_type='dataframe')
    # patsy drops rows with missing values but keeps the pandas index of the
    # rows it uses. Take group labels from exactly those rows so that absorb /
    # cluster labels line up with y and X and the counts reflect the estimation
    # sample.
    used = None if data is None else data.loc[y.index]

    if absorb is not None:
        fe_groups = used[absorb]
        # Within transformation: subtract group means, add back the grand mean
        # (the right-hand side is evaluated on the un-demeaned y / X).
        y = y - y.groupby(fe_groups).transform('mean') + y.mean()
        X = X - X.groupby(fe_groups).transform('mean') + X.mean()

    reg = sm.OLS(y, X)
    if absorb is not None:
        # The absorbed group means consume (number of groups - 1) extra
        # degrees of freedom that OLS on the demeaned data does not see.
        reg.df_resid -= (fe_groups.nunique() - 1)

    if cluster is None:
        return reg.fit()
    return reg.fit(cov_type='cluster', cov_kwds={'groups': used[cluster].values})

def get_cols_over_time(col_fmt='payments_by_{}'):
    """Build one column name per entry of ``Likelihood.event_dates``.

    Each event date is converted with ``date_to_str`` and substituted into
    ``col_fmt`` via ``str.format``. The names are returned as a list in the
    same order as ``Likelihood.event_dates``.
    """
    date_strings = map(date_to_str, Likelihood.event_dates)
    return [col_fmt.format(text) for text in date_strings]

def is_ever(df, cols, value, col_names=None):
    """Running "has ever equalled ``value``" indicators across ordered columns.

    Column ``k`` of the result is 1 for a row when ``value`` equals at least one
    of ``df[cols[0]], ..., df[cols[k]]`` in that row, and 0 otherwise.

    Parameters
    ----------
    df : pandas.DataFrame
        Source data. The result keeps ``df.index``.
    cols : sequence of column labels
        Columns of ``df`` to scan, in accumulation order.
    value : scalar
        Value to look for, compared with ``==``. Missing cells never match.
    col_names : sequence, optional
        Labels for the result's columns; must have ``len(cols)`` entries.
        Defaults to ``Likelihood.event_dates``, looked up at call time.

    Returns
    -------
    pandas.DataFrame
        int64 0/1 indicators, indexed like ``df``, with ``col_names`` as columns.
    """
    if col_names is None:
        col_names = Likelihood.event_dates
    # Element-wise matches, then a cumulative logical OR along the columns:
    # once a row has matched, every longer prefix of ``cols`` matches too.
    # This replaces the former per-prefix row-wise apply in a process pool
    # (which passed its arguments through module globals).
    hits = df.loc[:, list(cols)].eq(value).to_numpy(dtype=bool)
    ever = np.logical_or.accumulate(hits, axis=1)
    return pd.DataFrame(ever.astype(np.int64), index=df.index, columns=col_names)

# Suppress all warnings for the remainder of the script, then set the
# matplotlib defaults used for every figure.
warnings.filterwarnings("ignore")

rcParams.update({
    'figure.figsize': (13.0, 7.0),
    'lines.linewidth': 3,
    'font.size': 18,
    'text.usetex': False,
    # amsmath provides the \text command used in LaTeX labels.
    'text.latex.preamble': r'\usepackage{amsmath}',
})

df_status = pd.read_csv('data/jesus_maria_status_deidentified.csv')
# Rows with a non-missing LLAVE_UN make up the subsample.
df_status['subsample'] = df_status['LLAVE_UN'].notna()

# Ever-G1 indicator, taken from the '2021-09-06' column of is_ever's output
# (presumably the prefix of priority_cols up to that event date — confirm
# that priority_cols is ordered like Likelihood.event_dates).
is_ever_G1 = is_ever(df_status, priority_cols, 'G1')['2021-09-06']
df_status['ever_G1'] = is_ever_G1
# total_due scaled by 4 (presumably quarterly -> annual; confirm units).
df_status['annual_total_due'] = df_status['total_due'] * 4

# Table OA.3: full sample compared with the LLAVE_UN subsample.
# NOTE(review): the subsample is a subset of the main sample, so the two
# groups passed to ttest_ind are not independent — confirm this is intended.
in_subsample = df_status['subsample'] == 1
subsample_size = int(in_subsample.sum())
with open("figs/tableOA3.txt", "w") as text_file:
    for var in ['score_endo_covariates', 'score_exo_covariates',
                'annual_total_due', 'assignment_to_treatment', 'ever_G1']:
        whole = df_status[var]
        part = df_status.loc[in_subsample, var]
        print(var, file=text_file)
        print(f"Main sample - {var} (obs: {len(df_status)})", file=text_file)
        print(whole.mean(), file=text_file)
        print(" ", file=text_file)
        print(f"Subsample - {var} (obs: {subsample_size})", file=text_file)
        print(part.mean(), file=text_file)
        print(" ", file=text_file)
        print("TTest", file=text_file)
        print(scipy.stats.ttest_ind(whole, part), file=text_file)
        print(" ", file=text_file)

df_status = df_status.loc[df_status['subsample']==1]
# Table OA4: treatment/control balance within the subsample.
is_treated = df_status['assignment_to_treatment'] == 1
is_control = df_status['assignment_to_treatment'] == 0
n_treated = int(is_treated.sum())
n_control = int(is_control.sum())
with open("figs/tableOA4.txt", "w") as text_file:
    for var in ['score_endo_covariates', 'score_exo_covariates', 'annual_total_due']:
        treated_vals = df_status.loc[is_treated, var]
        control_vals = df_status.loc[is_control, var]
        print(var, file=text_file)
        print(f"Treatment - {var} (obs: {n_treated})", file=text_file)
        print(treated_vals.mean(), file=text_file)
        print(" ", file=text_file)
        print(f"Control - {var} (obs: {n_control})", file=text_file)
        print(control_vals.mean(), file=text_file)
        print(" ", file=text_file)
        print("TTest", file=text_file)
        print(scipy.stats.ttest_ind(treated_vals, control_vals), file=text_file)
        print(" ", file=text_file)


# Setup for Tables 3/4: block-level counts and exposure ratios.
# The subsample filter is kept so this section stands on its own. .copy()
# gives an owned DataFrame, so the column assignments and in-place sort below
# do not write into a `.loc` slice (pandas' SettingWithCopy situation, whose
# warning is silenced by the global warnings filter).
df_status = df_status.loc[df_status['subsample']==1].copy()

# Sample units per block: non-null id_scrambled values within each LLAVE_UN.
df_status['counts'] = df_status[['LLAVE_UN','id_scrambled']].groupby('LLAVE_UN').transform('count')
# Denominator: TotalPredios - 1 (presumably "other lots in the block";
# assumes TotalPredios is the block's total lot count — confirm).
df_status['sizedenom'] = df_status['TotalPredios'] - 1
df_status['ratiototal'] = df_status['counts']/df_status['sizedenom']

# Treated units per block. Treated rows exclude themselves from the numerator.
df_status['treatcount'] = df_status[['LLAVE_UN','assignment_to_treatment']].groupby('LLAVE_UN').transform('sum')
df_status['ratiotreat'] = df_status['treatcount']/df_status['sizedenom']
df_status.loc[df_status['assignment_to_treatment']==1, 'ratiotreat'] = (
        df_status['treatcount']-1)/df_status['sizedenom']

# Control indicator and control units per block.
# NOTE(review): unlike ratiotreat, ratiocontrol does not exclude the row
# itself for control units — confirm this asymmetry is intended.
df_status['control'] = 0
df_status.loc[df_status['assignment_to_treatment']==0, 'control'] = 1
df_status['controlcount'] = df_status[['LLAVE_UN','control']].groupby('LLAVE_UN').transform('sum')
df_status['ratiocontrol'] = df_status['controlcount']/df_status['sizedenom']

# Mark one representative row per block: the row(s) whose id_scrambled is the
# block minimum.
df_status.sort_values('LLAVE_UN', inplace=True)
df_status['min_id_in_block'] = df_status[['LLAVE_UN','id_scrambled']].groupby('LLAVE_UN').transform('min')
df_status['unique_block'] = 0
df_status.loc[(df_status['min_id_in_block']==df_status['id_scrambled']), 'unique_block'] = 1

# criteria = 1 for rows in blocks with at least one treated and one control unit.
df_status['criteria'] = 0
df_status.loc[(df_status['treatcount']>=1)
              & (df_status['controlcount']>=1), 'criteria'] = 1

# Table OA5: Block Descriptives — one representative row per block that
# satisfies `criteria` (at least one treated and one control unit).
representative_rows = df_status.loc[(df_status['unique_block'] == 1)
                                    & (df_status['criteria'] == 1)]
block_summary = representative_rows[
    ['TotalPredios', 'ratiototal', 'ratiotreat', 'ratiocontrol']].describe()
with open("figs/tableOA5.txt", "w") as text_file:
    print(f"Num obs: {len(representative_rows)}", file=text_file)
    print(block_summary, file=text_file)


# Regression variables:
#   totalpayments     - value of the last payment column
#   countnum          - other sample units in the block (counts - 1)
#   ratiototalreg     - countnum / sizedenom
#   decile_ratiototal - decile index (0..9) of ratiototalreg, divided by 10
df_status = df_status.assign(
    totalpayments=df_status[payment_cols[-1]],
    countnum=lambda frame: frame['counts'] - 1,
    ratiototalreg=lambda frame: frame['countnum'] / frame['sizedenom'],
    decile_ratiototal=lambda frame: pd.qcut(frame['ratiototalreg'], q=10, labels=False) / 10,
)


# Table OA6: regressions absorbing decile_ratiototal fixed effects, with
# standard errors clustered by LLAVE_UN.
def _write_reg_table(reg, text_file):
    """Write a fitted regression's coefficients (LaTeX), N and R2 to text_file.

    Returns the coefficient DataFrame that was written.
    """
    df_results = reg.params.to_frame('Parameter')
    df_results['Standard Error'] = reg.bse
    df_results['P Values'] = reg.pvalues
    text_file.write(df_results.to_latex(float_format="%.3f"))
    print("N obs:" + str(reg.nobs), file=text_file)
    print("R2:" + str(reg.rsquared), file=text_file)
    return df_results


with open("figs/tableOA6.txt", "w") as text_file:
    # Regression 1: control units in blocks meeting `criteria`.
    reg3_1 = areg('totalpayments ~ ratiotreat',
                  data=df_status.loc[(df_status['assignment_to_treatment']==0)
                                     & (df_status['criteria']==1)],
                  absorb='decile_ratiototal', cluster='LLAVE_UN')
    df_results = _write_reg_table(reg3_1, text_file)
    # Blank separator between the two tables; it must target text_file, not stdout.
    print("", file=text_file)

    # Regression 2: all units in blocks meeting `criteria`.
    reg3_2 = areg('totalpayments ~ ever_G1 + assignment_to_treatment + ratiotreat',
                  data=df_status.loc[(df_status['criteria']==1)],
                  absorb='decile_ratiototal', cluster='LLAVE_UN')
    df_results = _write_reg_table(reg3_2, text_file)

